import os
import math
from shutil import move

import numpy as np
import torch
from torch import nn

from data.create_cliport_programs import utterance2program_old_bdetr, merge_programs
from beauty_detr.bdetr_eval_utils import transform, rescale_bboxes
from models.langevin_dynamics import langevin
from tasks import place_cyan_in_purple
from utils.executor_utils import make_circle, make_line, clamp_boxes_torch

import wandb
import ipdb
# Debugging alias: calling st() opens an ipdb breakpoint at the call site.
st = ipdb.set_trace


# Extent of the coordinate frame the EBMs operate in.
# _pack_boxes_for_ebm maps normalized box centers from [0, 1] into
# [XMIN, XMAX] (x) and [YMIN, YMAX] (y); _unpack_boxes_from_ebm inverts it.
XMIN = -1
XMAX = 1
YMIN = -0.5
YMAX = 0.5
# NOTE(review): ZMIN, ZMAX, SIZE and LENS are not referenced anywhere else in
# the visible part of this file -- confirm external users before changing.
ZMIN = -1
ZMAX = 1
SIZE = 0.15

# Per-axis half-length of the workspace after subtracting SIZE from each
# axis extent (x, y, z order).
LENS = torch.tensor(
    [XMAX - XMIN - SIZE, YMAX - YMIN - SIZE, ZMAX - ZMIN - SIZE]
) / 2.0


class NS_Transporter(nn.Module):
    """General class for symbolic transporter"""

    def __init__(
            self,
            args,
            parser,
            bdetr_model,
            ebm_dict,
            visualize=False,
            verbose=False
    ):
        """Store the sub-modules and run-time options of the transporter.

        Args:
            args: option namespace; ``args.device`` is read here and the
                object is kept as ``self.args`` for later use.
            parser: callable turning utterances into programs.
            bdetr_model: grounding model used by ``_filter_pred``.
            ebm_dict: mapping handed to the EBM routines.
            visualize: when True, ``forward`` builds W&B images.
            verbose: when True, ``forward`` prints debugging information.
        """
        super().__init__()

        # Learned / external components.
        self.parser = parser
        self.bdetr_model = bdetr_model
        self.ebm_dict = ebm_dict

        # Behaviour flags.
        self.visualize, self.verbose = visualize, verbose

        # Run-time configuration.
        self.args = args
        self.device = args.device
        self.score_thresh = 0.5

        # The W&B run is started unconditionally (the `visualize` guard that
        # used to wrap this call is disabled).
        wandb.init(project="NS_Transporter", name="2d_vis")

        self.bad_frames = []

        # EBM bounds span (2, 1) whereas the robot's bounds span (1, 0.5),
        # hence the factor of two between the two frames.
        self.robot_bounds_to_ebm_bounds = 2.0

    def forward(self, batch):
        """Forward Pass."""
        module_outputs = {}

        # Run Parser
        programs = self._get_programs(
            batch['raw_utterances'], use_gt=self.args.gt_parsing,
            gt_programs=batch.get('program_lists', None),
            legacy=self.args.legacy
        )
        if self.verbose:
            print(batch['raw_utterances'])
            print(programs)
        module_outputs['pred_programs'] = programs

        batch_outputs = []

        # Run executors, seperately on each program
        for p, program in enumerate(programs):
            outputs = []
            img = batch['initial_frames'][p]
            ground_truth = batch['ground_truths'][p].copy() if batch['ground_truths'] is not None else None
            # st()
            visualize_outputs = []
            columns = []

            for i, op in enumerate(program):
                if op['op'] == 'detect_objects':
                    outputs.append([])
                    continue
                elif op['op'] == 'filter':
                    if self.verbose:
                        print('filter')
                    # run clip here
                    class_label, location = op['concept']

                    # need to merge class label and location
                    # now that we use bdetr
                    assert location == "none", print(location)
                    predictions = self._filter(
                        img, class_label, ground_truths=ground_truth,
                        use_gt=self.args.gt_grounding
                    )
                    if ground_truth is not None:
                        ground_truth.pop(0)
                    outputs.append(predictions)
                    if self.visualize:
                        visualize_outputs.append(
                            self._visualize(
                                img, predictions,
                                caption=f"filter {class_label}, location: {location}",
                                concept='filter'
                            )
                        )
                        columns.append(f'filter_{class_label}_{i}')
                
                elif op['op'] == 'binaryEBM':
                    concepts = []
                    picks = []
                    places = []
                    for j in range(i, len(program)):
                        op = program[j]
                        concept_, _ = op['concept']
                        pick, place = op['inputs']
                        concepts.append(concept_)
                        picks.append(pick)
                        places.append(place)
                    
                    height, width = img.shape[:2]

                    if self.args.gt_ebm:
                        predictions = batch['gt_place_boxes'][p]
                        predictions = clamp_boxes_torch(predictions, height, width)
                        
                        if batch['move_all'][p] and self.args.gt_grounding:
                            pick_boxes = torch.tensor(batch['ground_truths'][p]).to(predictions.device)
                        else:
                            pick_boxes = []
                            place_boxes = []
                            done_pick = []
                            for i_, pick in enumerate(picks):
                                if pick not in done_pick:
                                    pick_boxes.append(outputs[pick])
                                    done_pick.append(pick)
                                    
                            pick_boxes = torch.cat(pick_boxes, 0)
                        place_boxes = predictions

                        pick_boxes, place_boxes = self._make_compatible(pick_boxes, place_boxes)
                        outputs.append((pick_boxes, place_boxes))

                    else:
                        boxes = []
                        count = 0
                        picks_ = []
                        places_ = []
                        concepts_ = []
                        done_idx = []
                        done_ebm = []
                        move_all = batch['move_all'][p]
                        for i, pick in enumerate(picks):
                            place = places[i]
                            pick_boxes_ = outputs[pick]
                            place_boxes_ = outputs[place]

                            if move_all:
                                pick_boxes_ = pick_boxes_[:1]
                                place_boxes_ = place_boxes_[:1]

                            # make the num of pick and place boxes same
                            pick_boxes_, place_boxes_ = self._make_compatible(pick_boxes_, place_boxes_)

                            concepts_ += [concepts[i]] * len(pick_boxes_)
                            if pick not in done_idx:
                                boxes.append(pick_boxes_)
                                picks_ebm_ = [[p_] for p_ in range(count, count + len(pick_boxes_))]
                                picks_ += picks_ebm_
                                count += len(pick_boxes_)
                                done_idx.append(pick)
                                done_ebm.append(picks_ebm_)
                            else:
                                pick_done_idx = done_idx.index(pick)
                                picks_ += done_ebm[pick_done_idx]
                            
                            if place not in done_idx:
                                boxes.append(place_boxes_)
                                place_ebm_ = [[p_] for p_ in range(count, count + len(place_boxes_))]
                                places_ += place_ebm_
                                count += len(place_boxes_)
                                done_idx.append(place)
                                done_ebm.append(place_ebm_)
                            else:
                                place_done_idx = done_idx.index(place)
                                places_ += done_ebm[place_done_idx]
                            # boxes.append(place_boxes_)
                            # places_ += [[p_] for p_ in range(count, count + len(place_boxes_))]
                            # count += len(place_boxes_)
                        boxes = torch.cat(boxes, 0)
                        boxes_ebm = self._pack_boxes_for_ebm(boxes, height, width)


                        try:
                            assert len(picks_) == len(places_) and len(places_) == len(concepts_)
                        except:
                            st()
                        goal_boxes = self._run_ebm(boxes_ebm, picks_, places_, concepts_, move_all=move_all)
                        predictions = self._unpack_boxes_from_ebm(goal_boxes, height, width).reshape(-1, 4)
                        predictions = clamp_boxes_torch(predictions, height, width)

                        if move_all:
                            pick_boxes, place_boxes = boxes, predictions
                        else:
                            pick_boxes = torch.cat(
                                [outputs[pick] for pick in np.unique(picks)]
                            )
                            place_boxes = torch.cat(
                                [predictions[p][None] for pick in np.unique(np.array(picks_).reshape(-1, 1), axis=0) for p in pick]
                            )

                        assert len(pick_boxes) == len(place_boxes)

                        outputs.append((pick_boxes, place_boxes))

                    if self.visualize:
                        visualize_outputs.append(
                            self._visualize(
                                img, torch.round(place_boxes.to(torch.float32)),
                                caption=f"put",
                                concept=f"put"
                            )
                        )
                        columns.append(f'put')
                    break

                elif op['op'] == 'multiAryEBM':
                    # run shape ebm
                    # st()
                    shape_type, _, _, _ = op['concept']
                    if self.verbose:
                        print(op['op'], shape_type)
                    
                    if self.args.gt_grounding:
                        bboxs =  torch.tensor(batch['ground_truths'][p]).to(outputs[op['inputs'][0]].device)
                    else:
                        bboxs = outputs[op['inputs'][0]]
                    height, width = img.shape[:2]
                    if self.args.gt_ebm:
                        predictions = batch['gt_place_boxes'][p]
                    else:
                        boxes_ebm = self._pack_boxes_for_ebm(bboxs, height, width)
                        move_all = True
                        picks_ = [np.arange(len(bboxs))]
                        places_ = [[]]
                        concepts = [shape_type]
                        goal_boxes = self._run_ebm(
                            boxes_ebm, picks_, places_, concepts_, move_all=move_all)
                        predictions = self._unpack_boxes_from_ebm(goal_boxes, height, width).reshape(-1, 4)
                    

                    predictions = clamp_boxes_torch(predictions, height, width)
                    # pick = op['inputs'][0]
                    assert len(predictions) == len(bboxs)
                    outputs.append((bboxs, predictions.to(torch.float32)))

                    if self.visualize:
                        visualize_outputs.append(
                            self._visualize(
                                img, predictions,
                                caption=f"{shape_type}",
                                concept=f'{shape_type}'
                            )
                        )
                        columns.append(f'{shape_type}')

                else:
                    assert False, f'unknown op: {op}'

            batch_outputs.append(outputs[-1])

        return batch_outputs, module_outputs, visualize_outputs, columns

    @torch.no_grad()
    def _get_programs(self, phrases, use_gt=False,
                      gt_programs=None, legacy=False):
        """Turn utterances into programs.

        Args:
            phrases: list of str.
            use_gt: when True, ``gt_programs`` is returned untouched and the
                parser is not called.
            gt_programs: ground-truth programs returned when ``use_gt``.
            legacy: accepted for compatibility; currently not used.

        Returns:
            The parser's programs for ``phrases``. When the first phrase
            contains both " and " and " to the ", only that phrase is used:
            it is split on " and ", every part is parsed, and the parts are
            merged into a single program returned as a one-element list.
        """
        if use_gt:
            return gt_programs

        first_phrase = phrases[0]
        is_compound = " and " in first_phrase and " to the " in first_phrase

        if is_compound:
            sub_phrases = first_phrase.split(" and ")
            _, sub_programs = self.parser(
                sub_phrases, None,
                teacher_forcing=False, compute_loss=False
            )
            return [merge_programs(sub_programs)]

        _, parsed = self.parser(
            phrases, None,
            teacher_forcing=False, compute_loss=False
        )
        return parsed

    @torch.no_grad()
    def _filter(self, image, caption, ground_truths=None, use_gt=False):
        """Ground ``caption``: oracle boxes if ``use_gt``, else the detector."""
        if not use_gt:
            return self._filter_pred(image, caption)
        return self._filter_oracle(ground_truths, caption)

    @torch.no_grad()
    def _filter_pred(self, image, caption):
        """Detect the objects described by ``caption`` with the BDETR model.

        Args:
            image ([HXWX3]): image in unnormalized pixel format (0-256)
            caption ([string]): language description of object

        Returns:
            Bounding boxes ([NX4]): boxes ordered by decreasing confidence,
            as returned by ``rescale_bboxes`` for the height/width of the
            transformed input tensor.
        """
        # scale to [0, 1], go channel-first, normalize, add batch dim of 1
        chw = (image / 255.).permute(2, 0, 1)
        img = transform(chw).unsqueeze(0)
        prompt = f"find {caption}"

        # two-stage call: encode once, then decode from the cached memory
        cache = self.bdetr_model(img, [prompt], encode_and_save=True)
        outputs = self.bdetr_model(
            img, [prompt], encode_and_save=False, memory_cache=cache)

        # confidence = 1 - probability of the last class; keep > 0.7 and fall
        # back to the best-scoring prediction(s) when nothing passes.
        # (self.score_thresh is not consulted here.)
        probas = 1 - outputs['pred_logits'].softmax(-1)[0, :, -1].cpu()
        keep = (probas > 0.7).cpu()
        if keep.sum() == 0:
            keep = (probas == probas.max())

        # rank (score, box) pairs from most to least confident
        kept_scores = probas.cpu()[keep].tolist()
        kept_boxes = outputs['pred_boxes'].cpu()[0, keep].tolist()
        ranked = sorted(zip(kept_scores, kept_boxes), reverse=True)
        sorted_boxes = torch.cat(
            [torch.as_tensor(box).view(1, 4) for _, box in ranked])

        #  convert boxes from [0; 1] to image scales
        h, w = img.shape[2], img.shape[3]
        return rescale_bboxes(sorted_boxes, img_h=h, img_w=w)

    @torch.no_grad()
    def _filter_oracle(self, ground_truths, caption):
        """Return the first ground-truth entry as a tensor with a new leading axis.

        Args:
            ground_truths: indexable whose element 0 is convertible by
                ``torch.tensor`` (e.g. one ``[x1, y1, x2, y2]`` box).
            caption ([string]): accepted for interface parity with
                ``_filter_pred``; not used.

        Returns:
            ``torch.tensor(ground_truths[0])`` with a size-1 dimension
            prepended.
        """
        first_entry = torch.tensor(ground_truths[0])
        return first_entry.unsqueeze(0)

    def _multiaryEBM(self, boxes, height, width, shape_type,
               size=None, pos=None):
        """
        Run the shape EBM stored under ``shape_type`` on ``boxes``.

        ``boxes`` is handed to ``_pack_boxes_for_ebm`` together with
        ``height`` and ``width``; the first element of the model's output is
        converted back with ``_unpack_boxes_from_ebm``.

        NOTE(review): ``forward`` in this file does not call this method.
        NOTE(review): ``self.loc`` and ``self.size`` are not assigned in
        ``__init__`` in the visible code -- confirm where they are set.
        NOTE(review): the previous docstring described the input as
        (B, N, 4) [cx, cy, w, h] with x in [-0.5, 0.5] and y in [-1, 1];
        that does not match what ``_pack_boxes_for_ebm`` expects/produces --
        verify against the caller.
        """
        model = self.ebm_dict[shape_type]

        center = torch.tensor([self.loc[shape_type][pos]]).to(self.device)
        diameter = torch.tensor([self.size[shape_type][size]]).to(self.device)

        # convert from the robot frame to the EBM frame
        diameter *= self.robot_bounds_to_ebm_bounds

        # for a line the length grows with the number of boxes
        if shape_type == 'line':
            diameter *= len(boxes)

        # NOTE(review): _pack_boxes_for_ebm returns a NumPy array in this
        # file, which has no `.to` method -- this line looks stale; confirm.
        boxes = self._pack_boxes_for_ebm(boxes, height, width).to(self.device)
        goal_boxes, vis_boxes = model.run([boxes], [center, diameter])
        # Disabled visualisation of the optimisation trajectory (vis_boxes):
        # writer = SummaryWriter(f"runs/circle")
        # sw = utils.improc.Summ_writer(
        #     writer=writer,
        #     global_step=0,
        #     log_freq=1,
        #     fps=15,
        #     just_gif=True
        # )
        # vis_neg_ld_images = torch.stack([
        #     plot_relations_2d(
        #         visc[0].detach().cpu().numpy(),
        #         None
        #     )
        #     for visc in vis_boxes
        # ], dim=1)
        # sw.summ_gif('sample/start_to_goal_gif', vis_neg_ld_images)
        return self._unpack_boxes_from_ebm(goal_boxes[0], height, width)

    def _binaryEBM(self, input_obj, goal_obj, concept):
        """
        input_obj: list of obj that needs to be moved
        goal_obj: the object where to move
        """
        st()
        model = self.ebm_dict[concept]
        # preprocess inputs for ebms
        moved_objs = model.run(input_obj, goal_obj, test=True)
        return moved_objs

    def _run_ebm(self, boxes, subj, obj, rel, move_all):
        """
        Optimise ``boxes`` with Langevin dynamics under the given relations.

        boxes: np.array, (n_relevant_boxes, 4)
        subj: list of lists, rel boxes, e.g. [[1], [0, 9]], points to boxes
        obj: list of lists, ref boxes, e.g. [[2], []], points to boxes
        rel: list of str, relation names, e.g. ["right", "circle"]
        move_all: bool, whether to move all (True) or fix object box (False)

        subj, obj and rel must have one entry per relation; a relation
        without a reference object (e.g. a shape) uses [] in obj.

        returns the first output of ``langevin`` converted to a np.array
        """
        n_relations = len(obj)
        assert len(subj) == n_relations
        assert n_relations == len(rel)

        refined, _ = langevin(self.ebm_dict, boxes, subj, obj, rel, move_all)
        return refined.detach().cpu().numpy()

    @staticmethod
    def _make_compatible(pick_boxes_, place_boxes_):
        # make the num of pick and place boxes same
        if len(place_boxes_) < len(pick_boxes_):
            repeat_num = math.ceil(
            (len(pick_boxes_) / float(len(place_boxes_))))
            place_boxes_ = place_boxes_.repeat(repeat_num, 1)

        if len(place_boxes_) > len(pick_boxes_):
            place_boxes_ = place_boxes_[:len(pick_boxes_)]
        return pick_boxes_, place_boxes_


    @staticmethod
    def _crop_img_inside_bbox(img, bbox):
        """
        Inputs:
            img: H, W, 3
            bbox: N, 4 [x1, y1, x2, y2]
        Outputs:
            img_crops: N, h, w, 3
        """
        img_patches = []
        for box in bbox:
            try:
                img_patches.append(img[box[1]: box[3], box[0]: box[2]])
            except Exception as e:
                print(box, e)
        return img_patches

    @staticmethod
    def _visualize(img, detections=None, concept='scene', caption=None):
        """Wrap ``img`` and optional boxes into a ``wandb.Image``.

        Args:
            img: tensor whose first three channels of the last axis are
                logged (``img[..., :3]``).
            detections: iterable of boxes indexed as [x1, y1, x2, y2] whose
                entries support ``.item()``; None logs the bare image and
                prints "bad".
            concept: accepted but not used.
            caption: caption attached to the image.

        Returns:
            The ``wandb.Image``.

        NOTE(review): the ``wandb.Classes`` entry uses name=0 / id='obj',
        which looks swapped relative to the id->label map in
        ``class_labels`` -- confirm against the W&B API before changing.
        """
        has_detections = detections is not None

        all_boxes = []
        if has_detections:
            all_boxes = [
                {
                    "position": {
                        "minX": box[0].item(),
                        "maxX": box[2].item(),
                        "minY": box[1].item(),
                        "maxY": box[3].item()
                    },
                    "class_id": 0,
                    "domain": "pixel",
                }
                for box in detections
            ]

        if has_detections:
            boxes_payload = {"predictions": {
                "box_data": all_boxes,
                "class_labels": {0: "obj", 1: "rand"}
            }}
        else:
            boxes_payload = None

        box_image = wandb.Image(
            img[..., :3].cpu().numpy(),
            caption=caption,
            boxes=boxes_payload,
            classes=wandb.Classes([{"name": 0, 'id': 'obj'}])
        )
        if not has_detections:
            print("bad")
        return box_image

    def _pack_boxes_for_ebm(self, boxes, height, width):
        """Convert pixel corner boxes into the EBM's center-size format.

        Args:
            boxes: tensor (..., 4) holding (x1, y1, x2, y2).
            height, width: image size used to normalize y and x.

        Returns:
            float32 NumPy array (..., 4) holding (cx, cy, w, h): centers are
            mapped into [XMIN, XMAX] x [YMIN, YMAX]; sizes stay in
            normalized image units. The input tensor is not modified.
        """
        corners = torch.clone(boxes).to(torch.float32)

        # normalize each coordinate by the matching image dimension
        x1 = corners[..., 0] / width
        y1 = corners[..., 1] / height
        x2 = corners[..., 2] / width
        y2 = corners[..., 3] / height

        # corner -> center, then rescale the center into the EBM frame
        center_x = (x1 + x2) * 0.5
        center_y = (y1 + y2) * 0.5
        packed = torch.stack((
            center_x*(XMAX - XMIN) + XMIN,
            center_y*(YMAX - YMIN) + YMIN,
            x2 - x1,
            y2 - y1
        ), -1)
        return packed.cpu().numpy()

    def _unpack_boxes_from_ebm(self, boxes, height, width):
        """Convert EBM center-size boxes back to pixel corner boxes.

        Inverse of ``_pack_boxes_for_ebm``.

        Args:
            boxes: array (..., 4) holding (cx, cy, w, h) with centers in
                [XMIN, XMAX] x [YMIN, YMAX] and sizes in normalized units.
            height, width: image size used to un-normalize y and x.

        Returns:
            torch tensor (..., 4) holding (x1, y1, x2, y2) in pixels, with
            the dtype of the input array.
        """
        # Work on a private copy: the un-scaling below assigns in place and
        # would otherwise overwrite the centers of the caller's array.
        boxes = np.array(boxes, copy=True)

        # un-scale the centers back to [0, 1]
        boxes[..., 0] = (boxes[..., 0] - XMIN) / (XMAX - XMIN)
        boxes[..., 1] = (boxes[..., 1] - YMIN) / (YMAX - YMIN)

        # to (x1, y1, x2, y2)
        boxes = np.concatenate((
            boxes[..., :2] - boxes[..., 2:] / 2,
            boxes[..., :2] + boxes[..., 2:] / 2
        ), -1)

        # un-normalize
        boxes[..., (0, 2)] *= width
        boxes[..., (1, 3)] *= height
        return torch.from_numpy(boxes)

    @staticmethod
    def _get_pick_and_place_boxes(boxes, predictions):
        # boxes:  (N X 4) tensor
        # predictions: (N X 4) tensor
        # return pick_boxes and place_boxes (M X 4)
        st()
        moved = ~np.isclose(boxes, predictions).all(-1)
        pick_boxes = boxes[moved]
        place_boxes = boxes[moved]
        return pick_boxes, place_boxes 

    @staticmethod
    def _gt_binary_ebm(pick_boxes, place_boxes, concept=None):
        if len(place_boxes) < len(pick_boxes):
            repeat_num = math.ceil(
                (len(pick_boxes) / float(len(place_boxes))))
            place_boxes = place_boxes.repeat(repeat_num, 1)

        
        output_boxes = []
        for pick_box, place_box in zip(pick_boxes, place_boxes):
            xc, yc = (place_box[2:] + place_box[:2]) / 2.0
            w, h = pick_box[2:] - pick_box[:2]
            output_boxes.append(
                [
                    xc - (w / 2.0), yc - (h / 2.0),
                    xc + (w / 2.0), yc + (h / 2.0)
                ]
            )
        output_boxes = torch.tensor(output_boxes).reshape(-1, 4)
        return output_boxes

    def _gt_multiary_ebm(self, boxes, shape_type, size, pos):
        center = self.loc[shape_type][pos]
        diameter = self.size[shape_type][size]
        boxes = boxes.to(self.device) 

        diameter *= self.robot_bounds_to_ebm_bounds

        if shape_type == "line":
            output_boxes = make_line(boxes, center, diameter)
        elif shape_type == 'circle':
            output_boxes = make_circle(boxes, center, diameter)
        else:
            assert False, f"{shape_type} not supported"
        output_boxes = output_boxes.reshape(-1, 4)
        return output_boxes

    @staticmethod
    def _save_data(outputs, pick, place, batch, p):
        data = {
            "pick": outputs[pick].cpu().numpy(),
            "put": outputs[place].cpu().numpy(),
            "initial_frame_path": batch['initial_frame_paths'][p],
            "goal_frame_path": batch['goal_frame_paths'][p],
            "utterance": batch['raw_utterances'][p]
        }
        file_split = batch['initial_frame_paths'][p].split('/')
        root = '/'.join(file_split[:-2])
        name = file_split[-1].split('.')[0]
        dir = os.path.join(root, 'ebm_preds')
        file_name = os.path.join(dir, f"{name}.npy")
        os.makedirs(dir, exist_ok=True)
        print(file_name)
        np.save(file_name, data)
